Delete node in a BST

Time: O(H); Space: O(H), where H is the height of the tree; Difficulty: medium

Given a reference to the root node of a BST and a key, delete the node with the given key from the BST.

Return the (possibly updated) root node reference of the BST.

The deletion can be divided into two stages:

  1. Search for the node to remove.
  2. If the node is found, delete it.
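The two stages can also be written without recursion. The following is a sketch of a hypothetical iterative variant, `delete_node_iterative`, not the notebook's Solution1. It walks down to the node while tracking its parent (stage 1), then splices the node out, first copying in the in-order successor's value when the node has two children (stage 2). It keeps the O(H) time bound and, since there is no recursion, uses only O(1) extra space.

```python
class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def delete_node_iterative(root, key):
    """Hypothetical iterative deletion: O(H) time, O(1) extra space."""
    # Stage 1: search for the node, remembering its parent.
    parent, node = None, root
    while node and node.val != key:
        parent = node
        node = node.left if key < node.val else node.right
    if node is None:
        return root  # key not present; tree unchanged

    # Stage 2: delete the node.
    if node.left and node.right:
        # Two children: copy the in-order successor's value into node,
        # then remove the successor instead (it has no left child).
        succ_parent, succ = node, node.right
        while succ.left:
            succ_parent, succ = succ, succ.left
        node.val = succ.val
        parent, node = succ_parent, succ

    # node now has at most one child; link that child to the parent.
    child = node.left if node.left else node.right
    if parent is None:
        return child  # the root itself was removed
    if parent.left is node:
        parent.left = child
    else:
        parent.right = child
    return root


root = TreeNode(5)
root.left, root.right = TreeNode(3), TreeNode(6)
root.left.left, root.left.right = TreeNode(2), TreeNode(4)
root.right.right = TreeNode(7)
root = delete_node_iterative(root, 3)
print(root.left.val, root.left.left.val, root.left.right)  # 4 2 None
```

On Example 1 this produces the same tree as the successor-based recursive solution.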

Note:

  • Time complexity should be O(height of tree).

Example 1:

    5
   / \
  3   6
 / \   \
2   4   7

Input: root = [5,3,6,2,4,null,7] (level order, null marks a missing child), key = 3

Output:

  1. One valid answer is [5,4,6,2,null,null,7], shown in the following BST.

        5
       / \
      4   6
     /     \
    2       7
    
  2. Another valid answer is [5,2,6,null,4,null,7].

      5
     / \
    2   6
     \   \
      4   7
    

Explanation:

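  • The key to delete is 3, so we find the node with value 3 and delete it.

  • Node 3 has two children, so it must be replaced by a value that keeps the BST ordering: either its in-order successor 4 (the smallest value in its right subtree), which gives the first answer, or its in-order predecessor 2 (the largest value in its left subtree), which gives the second. Solution1 below uses the successor, so it produces the first answer.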
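A minimal sketch that reproduces both valid outputs of Example 1 by switching the replacement node between the in-order successor and the in-order predecessor. The helpers `build` and `serialize` (level-order lists, with Python `None` standing in for `null`) and the `use_successor` flag of `delete_node` are hypothetical, introduced only for this illustration.

```python
from collections import deque


class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def build(values):
    """Build a tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    it = iter(values)
    root = TreeNode(next(it))
    queue = deque([root])
    while queue:
        node = queue.popleft()
        left, right = next(it, None), next(it, None)
        if left is not None:
            node.left = TreeNode(left)
            queue.append(node.left)
        if right is not None:
            node.right = TreeNode(right)
            queue.append(node.right)
    return root


def serialize(root):
    """Level-order list with None for missing children, trailing Nones trimmed."""
    out, queue = [], deque([root])
    while queue:
        node = queue.popleft()
        if node is None:
            out.append(None)
            continue
        out.append(node.val)
        queue.append(node.left)
        queue.append(node.right)
    while out and out[-1] is None:
        out.pop()
    return out


def delete_node(root, key, use_successor=True):
    """Recursive deletion; a two-child node takes its successor or predecessor."""
    if root is None:
        return None
    if key < root.val:
        root.left = delete_node(root.left, key, use_successor)
    elif key > root.val:
        root.right = delete_node(root.right, key, use_successor)
    elif root.left is None:
        return root.right
    elif root.right is None:
        return root.left
    elif use_successor:
        succ = root.right  # leftmost node of the right subtree
        while succ.left:
            succ = succ.left
        root.val = succ.val
        root.right = delete_node(root.right, succ.val, use_successor)
    else:
        pred = root.left  # rightmost node of the left subtree
        while pred.right:
            pred = pred.right
        root.val = pred.val
        root.left = delete_node(root.left, pred.val, use_successor)
    return root


example = [5, 3, 6, 2, 4, None, 7]
print(serialize(delete_node(build(example), 3, use_successor=True)))
# [5, 4, 6, 2, None, None, 7]
print(serialize(delete_node(build(example), 3, use_successor=False)))
# [5, 2, 6, None, 4, None, 7]
```

Either choice is correct; the successor form is the one the notebook's solution uses.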

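class TreeNode(object):
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


class Solution1(object):
    """
    Recursive deletion; a node with two children is replaced by its
    in-order successor.
    Time: O(H)
    Space: O(H) for the recursion stack, where H is the tree height
    """
    def deleteNode(self, root, key):
        """
        :type root: TreeNode
        :type key: int
        :rtype: TreeNode
        """
        if not root:
            # Empty subtree: the key is not present.
            return root

        if root.val > key:
            # The key can only be in the left subtree.
            root.left = self.deleteNode(root.left, key)
        elif root.val < key:
            # The key can only be in the right subtree.
            root.right = self.deleteNode(root.right, key)
        else:
            # Found the node. With at most one child, splice it out by
            # returning that child (possibly None) to the caller.
            if not root.left:
                return root.right
            elif not root.right:
                return root.left
            else:
                # Two children: copy the in-order successor (the leftmost
                # node of the right subtree) into this node, then delete
                # the successor from the right subtree.
                successor = root.right
                while successor.left:
                    successor = successor.left

                root.val = successor.val
                root.right = self.deleteNode(root.right, successor.val)

        return root


s = Solution1()

# Example 1: root = [5,3,6,2,4,null,7], key = 3
root = TreeNode(5)
root.left = TreeNode(3)
root.right = TreeNode(6)
root.left.left = TreeNode(2)
root.left.right = TreeNode(4)
root.right.right = TreeNode(7)
key = 3
root = s.deleteNode(root, key)  # use the returned root
assert root.val == 5
assert root.left.val == 4
assert root.right.val == 6
assert root.left.left.val == 2
assert root.left.right is None
assert root.right.right.val == 7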